import os
import json
from tqdm import tqdm
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt


def extract_strictly_increasing_segment(array):
    """
    Finds the longest strictly increasing run of consecutive elements in a 1D array
    and outputs a mask marking it.

    A run is broken wherever ``array[i] <= array[i - 1]``. When several runs share
    the maximum length, the first (leftmost) one is returned. Since any comparison
    involving NaN is False, a NaN never breaks a run.

    Parameters:
    -----------
    array : array_like
        A 1D sequence of numeric values.

    Returns:
    --------
    tuple:
        - The longest strictly increasing segment, obtained by slicing ``array``
          itself (a NumPy array for NumPy input). Empty for empty input.
        - np.ndarray: A boolean mask with the same length as ``array`` that is True
          on the indices of the segment.

    Raises:
    -------
    ValueError
        If ``array`` is not one-dimensional.

    Example:
    --------
    >>> array = np.array([1, 2, 3, 2, 3, 4, 1, 2])
    >>> segment, mask = extract_strictly_increasing_segment(array)
    >>> segment
    array([1, 2, 3])
    >>> mask
    array([ True,  True,  True, False, False, False, False, False])
    """
    values = np.asarray(array)
    if values.ndim != 1:
        raise ValueError(f"expected a 1D array, got shape {values.shape}")

    # Index i starts a new run whenever it does not exceed its predecessor.
    breaks = np.flatnonzero(values[1:] <= values[:-1]) + 1
    run_starts = np.concatenate(([0], breaks))
    run_ends = np.concatenate((breaks, [len(values)]))
    run_lengths = run_ends - run_starts

    # argmax returns the first maximum, so ties resolve to the leftmost run.
    best = int(np.argmax(run_lengths))
    start = int(run_starts[best])
    stop = int(run_ends[best])

    # Slice the caller's object so the segment keeps the input's container type.
    best_segment = array[start:stop]

    best_mask = np.zeros_like(array, dtype=bool)
    best_mask[start:stop] = True

    return best_segment, best_mask




def combined_plot(list1, vlist1, list2=None, vlist2=None, colores=None, s_path=None, sname=None, save=False,
                  xaxis=None, yaxis=None, xlabel=None, ylabel=None, title=None):
    """
    Plot one or two collections of profiles on a single tall axis.

    The first collection is drawn as solid lines and the optional second collection
    as triangle-marker scatter points. Series ``i`` is labelled
    ``f'{280 + i * 10} K ...'`` (280 K, 290 K, ...), so the legend assumes the
    series are ordered in 10 K steps starting at 280 K.

    Parameters:
    - list1 (sequence of arrays): X-axis data for the first dataset (lines).
    - vlist1 (sequence of arrays): Y-axis data for the first dataset; vlist1[i] pairs with list1[i].
    - list2 (sequence of arrays, optional): X-axis data for the second dataset (scatter). Default is None.
    - vlist2 (sequence of arrays, optional): Y-axis data for the second dataset. The second dataset
      is plotted only when both list2 and vlist2 are given and non-empty.
    - colores (sequence of str, optional): One color per series index, shared by both datasets.
      When None or empty, lines are 'blue' and scatter points 'red'.
    - s_path (str, optional): Directory to save the plot in. Required if save=True.
    - sname (str, optional): File name of the saved plot. Required if save=True.
    - save (bool, optional): If True, save the figure and close it; otherwise show it. Default is False.
    - xaxis (sequence of float, optional): X tick positions; the X limits are set to their min/max
      and the ticks are labelled with 2 decimals. If None, auto-scaling is used.
    - yaxis (sequence of float, optional): Y tick positions; the Y limits are set to (max, min),
      which flips the axis, and the ticks are labelled with 4 decimals. If None, the auto-scaled
      Y axis is inverted.
    - xlabel (str, optional): X-axis label. Defaults to 'Fractional convective area'.
    - ylabel (str, optional): Y-axis label. Defaults to '[K]'.
    - title (str, optional): Plot title. Defaults to 'Convective event +- 0.3 ms-1 qn only for +'.

    Returns:
    - None: Displays or saves the plot.

    Raises:
    - ValueError: If save=True but s_path or sname is missing.
    """
    # Validate before creating the figure so a bad call does not leave an
    # orphaned figure registered with pyplot.
    if save and not (s_path and sname):
        raise ValueError("s_path and sname must be provided if save=True.")

    # Use len() rather than plain truthiness, which raises for NumPy arrays.
    has_colors = colores is not None and len(colores) > 0
    plot_second = (list2 is not None and vlist2 is not None
                   and len(list2) > 0 and len(vlist2) > 0)

    fig, ax = plt.subplots(figsize=(6, 13))

    # First dataset: solid lines.
    for i, x_values in enumerate(list1):
        color = colores[i] if has_colors else 'blue'
        ax.plot(x_values, vlist1[i], color=color, label=f'{280 + i * 10} K original')

    # Second dataset: triangle-marker scatter points.
    if plot_second:
        for k, x_values in enumerate(list2):
            color = colores[k] if has_colors else 'red'
            ax.scatter(x_values, vlist2[k], color=color, label=f'{280 + k * 10} K interpolated', marker='^')

    ax.set_xlabel(xlabel if xlabel else 'Fractional convective area')
    if xaxis is not None:
        ax.set_xlim(min(xaxis), max(xaxis))
        ax.set_xticks(xaxis)
        ax.set_xticklabels([f'{x:.2f}' for x in xaxis])

    ax.set_ylabel(ylabel if ylabel else '[K]')
    if yaxis is not None:
        # Limits given as (max, min) flip the axis, matching the inverted default.
        ax.set_ylim(max(yaxis), min(yaxis))
        ax.set_yticks(yaxis)
        ax.set_yticklabels([f'{y:.4f}' for y in yaxis])
    else:
        ax.invert_yaxis()

    ax.set_title(title if title else 'Convective event +- 0.3 ms-1 qn only for +')

    ax.grid()
    ax.legend()

    if save:
        fig.savefig(os.path.join(s_path, sname))
        # Release the saved figure; otherwise every saved plot stays open in
        # pyplot (matplotlib warns by default once more than 20 are open).
        plt.close(fig)
    else:
        plt.show()


def combined_plot_change_per_perK(list1, vlist1, list2=None, vlist2=None, colores=None, title=None, s_path=None,
                                  sname=None, save=False, labels=None):
    """
    Plot per-kelvin changes of one or two datasets on a fixed X range.

    The first dataset is drawn as solid lines (legend suffix 'A_up') and the
    optional second dataset as dashed lines (legend suffix 'MF'). The X axis is
    fixed to ticks from -0.09 to 0.01 in steps of 0.01, and the Y axis is inverted.

    Parameters:
    - list1 (sequence of arrays): X-axis data for the first dataset.
    - vlist1 (sequence of arrays): Y-axis data for the first dataset; vlist1[i] pairs with list1[i].
    - list2 (sequence of arrays, optional): X-axis data for the second dataset. Default is None.
    - vlist2 (sequence of arrays, optional): Y-axis data for the second dataset. The second dataset
      is plotted only when both list2 and vlist2 are given and non-empty.
    - colores (sequence of str, optional): One color per series index, shared by both datasets.
      When None or empty, the first dataset is 'blue' and the second 'red'.
    - title (str, optional): Title of the plot. Default is None (no title).
    - s_path (str, optional): Directory to save the plot in. Required if save=True.
    - sname (str, optional): File name of the saved plot. Required if save=True.
    - save (bool, optional): If True, save the figure and close it; otherwise show it. Default is False.
    - labels (sequence of str, optional): Legend prefix for each series index. Defaults to
      ['280 -> 290K', '290 -> 300K', '300 -> 310K']; must have at least as many entries as series.

    Returns:
    - None: Displays or saves the plot.

    Raises:
    - ValueError: If save=True but s_path or sname is missing.
    """
    # Validate before creating the figure so a bad call does not leave an
    # orphaned figure registered with pyplot.
    if save and not (s_path and sname):
        raise ValueError("s_path and sname must be provided if save=True.")

    if labels is None:
        labels = ['280 -> 290K', '290 -> 300K', '300 -> 310K']

    # Use len() rather than plain truthiness, which raises for NumPy arrays.
    has_colors = colores is not None and len(colores) > 0
    plot_second = (list2 is not None and vlist2 is not None
                   and len(list2) > 0 and len(vlist2) > 0)

    # Fixed tick positions -0.09, -0.08, ..., 0.01.
    xaxis = np.arange(-0.09, 0.02, 0.01)

    fig, ax = plt.subplots(figsize=(10, 6))

    # First dataset: solid lines.
    for i, x_values in enumerate(list1):
        color = colores[i] if has_colors else 'blue'
        ax.plot(x_values, vlist1[i], color=color, label=f'{labels[i]} A_up')

    # Second dataset: dashed lines.
    if plot_second:
        for k, x_values in enumerate(list2):
            color = colores[k] if has_colors else 'red'
            ax.plot(x_values, vlist2[k], color=color, label=f'{labels[k]} MF', linestyle='dashed')

    ax.set_xticks(xaxis)
    ax.set_xlim(min(xaxis), max(xaxis))
    ax.set_xticklabels([f'{x:.2f}' for x in xaxis])

    ax.set_xlabel('1/[K]')
    ax.set_ylabel('[K]')
    ax.set_title(title)

    ax.invert_yaxis()
    ax.legend()

    if save:
        fig.savefig(os.path.join(s_path, sname))
        # Release the saved figure; otherwise every saved plot stays open in pyplot.
        plt.close(fig)
    else:
        plt.show()

def one_by_two(x1, x2, y1, y2, labels, xlabel, ylabel, ylabel2, t1, t2, savedir, sname, save):
    """
    Creates a 1x2 subplot figure with shared x-axis, plotting each profile plus a
    thick black 'Mean' profile on each panel.

    Parameters:
    -----------
    x1 : sequence of arrays
        x-axis data for the first subplot.
    x2 : sequence of arrays
        x-axis data for the second subplot.
    y1 : sequence of arrays
        y-axis data for the first subplot; y1[j] pairs with x1[j].
    y2 : sequence of arrays
        y-axis data for the second subplot; y2[j] pairs with x2[j].
    labels : sequence
        Legend label for each series index (shared by both subplots).
    xlabel : str
        Label for the x-axis, used on both subplots.
    ylabel : str
        Label for the y-axis of the first subplot.
    ylabel2 : str
        Label for the y-axis of the second subplot (drawn on the right side).
    t1 : str
        Title for the first subplot.
    t2 : str
        Title for the second subplot.
    savedir : str
        Directory where the figure is saved if `save` is True.
    sname : str
        File name to save the figure as.
    save : bool
        If True, save the figure before displaying it.

    Returns:
    --------
    None
        The figure is always displayed; it is also saved when `save` is True.

    Notes:
    ------
    - All profiles (x and y, both panels) are truncated to the shortest one; the
      'Mean' curve is the element-wise mean of the truncated x arrays, plotted
      against the first y profile of its panel. This assumes every profile shares
      the same levels from index 0 onward.
    - The y-axes of both subplots are inverted.

    Example usage:
    --------------
    one_by_two(x1_data, x2_data, y1_data, y2_data, labels, 'X-Axis Label', 'Y-Axis 1 Label',
               'Y-Axis 2 Label', 'Title 1', 'Title 2', '/path/to/save', 'figure_name.png', True)
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 7), layout='constrained', sharex=True)

    # Individual profiles.
    for j, x_values in enumerate(x1):
        ax1.plot(x_values, y1[j], label=f'{labels[j]}')
    for j, x_values in enumerate(x2):
        ax2.plot(x_values, y2[j], label=f'{labels[j]}')

    # Unpack into one list: `y1 + y2` would add element-wise for NumPy inputs.
    # x lengths are included so the truncated x arrays are never ragged.
    all_profiles = [*x1, *x2, *y1, *y2]
    if all_profiles:
        min_len = min(len(arr) for arr in all_profiles)
        # The mean is drawn before legend() so that it gets a legend entry.
        if len(x1) > 0:
            x1_mean = np.mean([np.asarray(arr)[:min_len] for arr in x1], axis=0)
            ax1.plot(x1_mean, y1[0][:min_len], color='black', linewidth=3, label='Mean')
        if len(x2) > 0:
            x2_mean = np.mean([np.asarray(arr)[:min_len] for arr in x2], axis=0)
            ax2.plot(x2_mean, y2[0][:min_len], color='black', linewidth=3, label='Mean')

    ax1.set_xlabel(xlabel)
    ax1.set_ylabel(ylabel)
    ax1.set_title(t1)
    ax1.grid()
    ax1.invert_yaxis()
    ax1.legend()

    ax2.set_xlabel(xlabel)
    ax2.yaxis.set_label_position("right")
    ax2.yaxis.tick_right()
    ax2.set_ylabel(ylabel2)
    ax2.grid()
    ax2.set_title(t2)
    ax2.invert_yaxis()
    ax2.legend()

    if save:
        fig.savefig(os.path.join(savedir, sname))

    plt.show()




def mask_precipitation(precip, conv_mask):
    """
    Masks the precipitation data based on the presence of convective events aloft.

    Parameters:
    precip (array_like): A 3D array of shape (time_steps, height, width) with the precipitation rate.
    conv_mask (array_like): A 4D array of shape (time_steps, pfull, height, width) marking convective
                            events at each pressure level. Any non-zero (truthy) value counts as an event.

    A grid point keeps its precipitation value if its vertical column in conv_mask contains at least
    one event at the same time step; otherwise it is set to NaN. Only the first precip.shape[0]
    time steps of conv_mask are consulted.

    Returns:
    np.ndarray: A masked copy of precip, with NaNs where there are no convective events. Floating-point
                (or complex) input keeps its dtype; any other dtype is converted to float64 so that NaN
                can be stored. The input array is not modified.
    """
    precip = np.asarray(precip)

    # NaN cannot be stored in an integer array, so promote non-inexact input to float.
    if np.issubdtype(precip.dtype, np.inexact):
        prec_masked = precip.copy()
    else:
        prec_masked = precip.astype(float)

    # One vectorized reduction over the pressure axis for all time steps at once;
    # the slice keeps the original per-step semantics of using only the first T steps.
    conv = np.asarray(conv_mask)[: prec_masked.shape[0]]
    has_convective_event = np.any(conv, axis=1)  # (time_steps, height, width)

    prec_masked[~has_convective_event] = np.nan

    return prec_masked


def prec_as_mask(precip, fourDfield, prec_bins, pfull,ndiag):
    """
    Averages a 4D diagnostic over all (time, horizontal) points whose precipitation
    falls in each bin, returning one vertical profile per precipitation bin.

    Parameters:
    precip (array_like): 3D array (time_steps, ny, nx) of precipitation rate.
    fourDfield (array_like): 4D array (time_steps, pfull, ny, nx) of the diagnostic; its time and
                             horizontal axes must match precip's.
    prec_bins (sequence of float): Bin edges; bin b covers the half-open interval
                                   (prec_bins[b], prec_bins[b + 1]].
    pfull (array_like): 1D pressure levels matching fourDfield's second dimension.
    ndiag (str): Diagnostic name; the result is named "mean_" + ndiag.

    Returns:
    xarray.DataArray: Dimensions ("pfull", "prec_bins"). The "prec_bins" coordinate holds labels
                      formatted as "lower-upper" with 3 decimal places. NaNs already in the field
                      are ignored. A bin with no points yields NaN on every level; numpy emits a
                      "Mean of empty slice" RuntimeWarning in that case.
    """
    precip = np.asarray(precip)
    field = np.asarray(fourDfield)

    # Move pfull to the front: (pfull, time_steps, ny, nx). A 3D boolean mask over
    # (time_steps, ny, nx) then selects the matching columns for every level at once,
    # giving a compact (pfull, n_selected) array. This avoids building a full-size,
    # NaN-filled copy of the 4D field for every bin.
    field_by_level = np.moveaxis(field, 1, 0)

    bin_labels = [f"{prec_bins[i]:.3f}-{prec_bins[i + 1]:.3f}" for i in range(len(prec_bins) - 1)]

    level_means = []
    for b in tqdm(range(len(prec_bins) - 1), desc="Processing bins"):
        lim_lower, lim_upper = prec_bins[b], prec_bins[b + 1]

        in_bin = np.logical_and(precip > lim_lower, precip <= lim_upper)
        selected = field_by_level[:, in_bin]  # (pfull, n_points_in_bin)

        level_means.append(np.nanmean(selected, axis=1))

    # Columns are bins; keep a well-formed (pfull, 0) array when there are no bins.
    if level_means:
        result_data = np.stack(level_means, axis=1)
    else:
        result_data = np.empty((field.shape[1], 0))

    return xr.DataArray(
        result_data,
        dims=["pfull", "prec_bins"],
        coords={
            "pfull": pfull,
            "prec_bins": bin_labels,
        },
        name="mean_" + ndiag,
    )

def mask_checker(precip, fourDfield, limL, limU, pfull, ndiag):
    """
    Keeps a 4D diagnostic only where precipitation lies in the half-open range (limL, limU].

    Parameters:
    precip (np.ndarray): 3D array (time_steps, y_t, x_t) of precipitation rate.
    fourDfield (np.ndarray): 4D array (time_steps, pfull, y_t, x_t) of the diagnostic to mask.
    limL (float): Exclusive lower precipitation bound.
    limU (float): Inclusive upper precipitation bound.
    pfull (np.ndarray): 1D pressure levels matching fourDfield's second dimension.
    ndiag (str): Diagnostic name; the result is named "masked_" + ndiag.

    Returns:
    xarray.DataArray: Dimensions ("time", "pfull", "y_t", "x_t"). Points outside the range are NaN
                      on every level. The time, y_t and x_t coordinates are 1-based integer indices.
    """
    in_range = (precip > limL) & (precip <= limU)

    # Insert a length-1 level axis so the 3D mask broadcasts over every pfull level.
    column_mask = in_range[:, None, :, :]
    kept = np.where(column_mask, fourDfield, np.nan)

    def _one_based(count):
        # 1, 2, ..., count as positional coordinate labels.
        return np.arange(1, count + 1)

    coordinates = {
        "time": _one_based(fourDfield.shape[0]),
        "pfull": pfull,
        "y_t": _one_based(fourDfield.shape[2]),
        "x_t": _one_based(fourDfield.shape[3]),
    }

    return xr.DataArray(
        kept,
        dims=["time", "pfull", "y_t", "x_t"],
        coords=coordinates,
        name="masked_" + ndiag,
    )




def save_to_json(data, dir_path, filename):
    """
    Serializes data as indented JSON to dir_path/filename, overwriting any existing file.

    Parameters:
    data (dict): JSON-serializable data to save.
    dir_path (str): Directory for the file; created (including parents) if missing.
                    An empty string means the current working directory.
    filename (str): Name of the JSON file to write.

    Raises:
    TypeError: If data contains objects the json module cannot serialize (e.g. numpy arrays).
    """
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    # os.makedirs('') raises, and the current directory always exists, so skip it.
    if dir_path:
        os.makedirs(dir_path, exist_ok=True)

    file_path = os.path.join(dir_path, filename)

    # Explicit encoding so the output does not depend on the platform's default.
    with open(file_path, 'w', encoding='utf-8') as json_file:
        json.dump(data, json_file, indent=4)

